from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from PIL import Image
from datasets import load_dataset
import torch
from peft import LoraConfig, get_peft_model
import yaml
import json
from tqdm import tqdm


# Base LLaVA-NeXT (Mistral-7B) model in bfloat16; downloaded weights are cached
# in the current working directory.
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf",
    cache_dir='./',
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
)

# LoRA adapters on the attention q/k/v projections.
# NOTE(review): this config presumably must mirror the one used for fine-tuning,
# otherwise the checkpoint's parameter names will not line up — confirm against
# the training script. "classifier" does not obviously name a LLaVA-NeXT
# submodule; it is kept unchanged for parity with training.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj"],
    bias="none",
    modules_to_save=["classifier"],
)
model = get_peft_model(model, lora_config)

# Fine-tuned weights, loaded onto CPU first. The name reflects that the keys do
# not match the PEFT-wrapped model as-is, so they are renamed before loading.
# NOTE(review): torch.load unpickles the file; only point it at trusted checkpoints.
incompatible_state_dict = torch.load(
    '/home/user/llavafinetune/output_model_quark/model_final.pt',
    map_location='cpu',
)

# Rename checkpoint keys to the PEFT-wrapped model's parameter names by collapsing
# the first '.model.model' in each key to '.model'.
# NOTE(review): the checkpoint presumably carries one extra '.model' nesting level
# from the training-time wrapper — confirm against the training script.
# replace(..., 1) rewrites only the first occurrence. Keys without the pattern pass
# through unchanged instead of raising IndexError, and text after a second occurrence
# is preserved rather than dropped. Strict loading below still reports any key that
# does not match the model.
state_dict = {
    key.replace('.model.model', '.model', 1): value
    for key, value in incompatible_state_dict.items()
}

# load_state_dict is strict by default, so any remaining key mismatch raises here.
model.load_state_dict(state_dict)
model.to('cuda')
model.eval()  # inference only: turns off dropout (including lora_dropout)

# Tokenizer + image preprocessor for the same base checkpoint as the model,
# sharing its local download cache.
processor = LlavaNextProcessor.from_pretrained(
    'llava-hf/llava-v1.6-mistral-7b-hf',
    cache_dir='./',
)

def inference(image, prompt1, prompt2, response1, max_new_tokens=1024):
    """Generate the model's second-turn reply in a two-turn, single-image chat.

    The first user turn carries the image. ``response1`` is inserted verbatim as
    the assistant's first reply, and the model then answers ``prompt2``.

    Args:
        image: Image path (or file object) accepted by ``PIL.Image.open``.
        prompt1: First user question (the turn that includes the image).
        prompt2: Follow-up user question.
        response1: Assistant answer to ``prompt1``, used as conversation history.
        max_new_tokens: Generation budget for the reply (default 1024).

    Returns:
        The decoded second-turn reply with surrounding whitespace stripped.
    """
    # Every user turn is prefixed with "good".
    # NOTE(review): this looks like the Quark reward-conditioning token used during
    # fine-tuning — confirm before changing it.
    template1 = "[INST] <image>\ngood {} [/INST]"
    template2 = "[INST] good {} [/INST]"
    # NOTE(review): the turns are concatenated with no separator or "</s>" between
    # response1 and the second [INST]. This presumably mirrors the training format — verify.
    instruction = template1.format(prompt1) + response1 + template2.format(prompt2)

    # The context manager closes the image file once preprocessing has read it.
    with Image.open(image) as img:
        # Pass keyword arguments: newer transformers releases changed the
        # positional (text, images) order of the processor call.
        inputs = processor(text=instruction, images=img, return_tensors="pt")
    inputs = {k: v.cuda() for k, v in inputs.items()}

    with torch.no_grad():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # generate() returns prompt + continuation. Decode only the new tokens rather
    # than splitting on "[/INST]", which fails or truncates when the history or
    # the reply itself contains that marker.
    prompt_len = inputs["input_ids"].shape[1]
    answer = processor.decode(output[0, prompt_len:], skip_special_tokens=True)
    return answer.strip()




# Load every multi-turn example up front. Blank lines (e.g. a trailing newline)
# are skipped so they do not raise JSONDecodeError.
with open('/home/user/llavafinetune/multiturn/gpt4v_generated_feedback.jsonl', 'r', encoding='utf-8') as f:
    multiturn_data = [json.loads(line) for line in f if line.strip()]

# Each record must provide 'img_path', 'question', 'answer' and 'roundtwo'. The
# model's second-turn reply is stored under 'final_response'. Each result is
# written and flushed as soon as it is produced, so an interruption part-way
# through keeps all completed examples.
results = []
with open('/home/user/llavafinetune/multiturn/quark_final_response.jsonl', 'w', encoding='utf-8') as out:
    for item in tqdm(multiturn_data):
        item['final_response'] = inference(
            item['img_path'],   # image
            item['question'],   # prompt1
            item['roundtwo'],   # prompt2
            item['answer'],     # response1
        )
        results.append(item)
        out.write(json.dumps(item) + '\n')
        out.flush()
    
